# Copyright (c) Microsoft. All rights reserved.

# Licensed under the MIT license. See LICENSE.md file in the project root
# for full license information.
# ==============================================================================

import argparse
import numpy as np
import sys
import os
import cntk as C
from mmdnn.conversion.examples.imagenet_test import TestKit

class TestCNTK(TestKit):
    """Test a converted CNTK model with MMdnn's ImageNet test kit.

    ``self.MainModel``, ``self.args`` and ``self.truth`` are used but not
    created here; they are presumably set up by ``TestKit.__init__`` (the
    generated model module, the parsed command-line arguments and the
    reference-prediction table) -- TODO confirm against TestKit.
    """

    def __init__(self):
        super(TestCNTK, self).__init__()

        # Reference top-5 predictions as (class index, score) pairs for models
        # converted to CNTK from other frameworks; presumably compared against
        # by TestKit.test_truth().
        self.truth['mxnet']['inception_bn'] = [(21, 0.84820729), (144, 0.06263639), (677, 0.015408826), (973, 0.014532777), (562, 0.0053690737)]

        self.truth['keras']['resnet'] = [(144, 0.77398175), (23, 0.10650793), (21, 0.081077583), (146, 0.0092755388), (562, 0.0089645367)]
        self.truth['tensorflow']['resnet'] = [(22, 13.370872), (147, 8.8040094), (24, 5.6983061), (90, 5.6143088), (95, 4.8060427)]

        self.model = self.MainModel.KitModel(self.args.w)
        # Debug only: to use print_intermediate_result(), make KitModel return
        # a (model, testop) pair and unpack it here instead:
        # self.model, self.testop = self.MainModel.KitModel(self.args.w)


    def preprocess(self, image_path):
        """Preprocess ``image_path`` with TestKit and keep the result in ``self.data``."""
        self.data = super(TestCNTK, self).preprocess(image_path)


    def print_result(self):
        """Evaluate the model on ``self.data`` (as a batch of one) and print it."""
        predict = self.model.eval({self.model.arguments[0]: [self.data]})
        super(TestCNTK, self).print_result(predict)


    def print_intermediate_result(self, layer_name, if_transpose = False):
        """Evaluate the debug op ``self.testop`` on ``self.data`` and print it.

        Args:
            layer_name: accepted but not used by this implementation.
            if_transpose: forwarded to ``TestKit.print_intermediate_result``.

        Raises:
            AttributeError: if ``self.testop`` has not been set. ``__init__``
                only assigns ``self.model`` unless it is edited to unpack a
                debug op from KitModel.
        """
        testop = getattr(self, 'testop', None)
        if testop is None:
            # Fail with an actionable message instead of a bare attribute lookup error.
            raise AttributeError(
                "TestCNTK.testop is not set: make KitModel return "
                "(model, testop) and unpack it in TestCNTK.__init__ "
                "to inspect intermediate results.")
        test_arr = testop.eval({testop.arguments[0]: [self.data]})
        super(TestCNTK, self).print_intermediate_result(test_arr, if_transpose)


    def inference(self, image_path):
        """Preprocess ``image_path``, print the prediction and check it against ``self.truth``."""
        self.preprocess(image_path)

        # Uncomment to inspect an intermediate op (requires self.testop):
        # self.print_intermediate_result(None, False)

        self.print_result()

        self.test_truth()

    def dump(self, path=None):
        """Save the CNTK model to ``path`` (defaults to ``self.args.dump``)."""
        if path is None:
            path = self.args.dump
        self.model.save(path)
        print('CNTK model file is saved as [{}], generated by [{}.py] and [{}].'.format(
            path, self.args.n, self.args.w))

    def detect(self, image_path, path=None):
        """Evaluate every model output on the image and dump each result to disk.

        Unlike ``print_result``, the preprocessed image is divided by 255 before
        evaluation. Each result is written to ``finalconv_<shape[1]>.npy`` in
        the current working directory, so two outputs whose second dimension
        matches overwrite each other.

        Args:
            image_path: image to preprocess and feed to the model.
            path: accepted (the script passes ``args.detect``) but currently
                unused; the output file names are fixed.
        """
        self.preprocess(image_path)
        # self.model is iterated and measured here, so for detection it is
        # presumably a sequence of CNTK functions, one per output -- TODO confirm.
        print("Found {} outputs".format(len(self.model)))
        scaled = self.data / 255.
        for output in self.model:
            predict = output.eval({output.arguments[0]: [scaled]})
            out_file = "finalconv_{}.npy".format(predict.shape[1])
            # ndarray.dump pickles the array; read it back with
            # np.load(out_file, allow_pickle=True).
            predict.dump(out_file)
            print('The output of CNTK model file is saved as [{}].'.format(out_file))

        print('generated by [{}.py], [{}] and [{}].'.format(self.args.n, self.args.w, image_path))

if __name__ == '__main__':
    # Dispatch on the command-line flags: dump the model, run the detection
    # dump, or (by default) run classification inference and check it.
    tester = TestCNTK()
    cli = tester.args
    if cli.dump:
        tester.dump()
    elif cli.detect:
        tester.detect(cli.image, cli.detect)
    else:
        tester.inference(cli.image)
